{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# 使用精确的提示词完成符合格式输出的文本分类任务\n",
    "\n",
    "在本章节中，我们将带领开发者体验在 `文本分类` 任务中，在不进行任何模型微调的情况下，使用具有微小差异的提示词，并使用 `ChatGLM3-6B` 模型来完成符合格式输出的文本分类任务。\n",
    "\n",
    "我们使用 [新闻标题分类](https://github.com/fateleak/toutiao-multilevel-text-classfication-dataset) 任务来体验模型的表现。这是一个经典的文本分类任务，我们将使用 `新闻信息` 作为输入，模型需要预测出这个标题属于哪个类别。\n",
    "\n",
    "由于 `ChatGLM3-6B` 强大的能力，我们可以直接使用 `新闻标题` 作为输入，而不需要额外的信息，也不需要进行任何模型微调，就能完成这个任务。我们的目标是，让模型能够成功输出原始数据集中的15种类别中的一种作为分类结果，而且不能输入任何冗余的对话。\n",
    "\n",
    "在本章节中，用户将直观的对比两种不同细粒度的提示词下对模型分类造成的影响。\n",
    "\n",
    "## 硬件要求\n",
    "本实践手册需要使用 FP16 精度的模型进行推理，因此，我们推荐使用至少 16GB 显存的 英伟达 GPU 来完成本实践手册。"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "引入必要的库"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "outputs": [],
   "source": [
    "from transformers import AutoModel, AutoTokenizer\n",
    "import torch"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "start_time": "2023-11-23T19:23:03.002351Z",
     "end_time": "2023-11-23T19:23:03.016701Z"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "设置好对应的参数，以保证模型推理的公平性。"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "outputs": [],
   "source": [
    "max_new_tokens = 1024\n",
    "temperature = 0.1\n",
    "top_p = 0.9\n",
    "device = \"cuda\"\n",
    "model_path_chat = \"THUDM/chatglm3-6b\""
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "start_time": "2023-11-23T19:23:03.004850Z",
     "end_time": "2023-11-23T19:23:03.047154Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "outputs": [
    {
     "data": {
      "text/plain": "Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "d6468f7889994e638ac754fde95b6e58"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "tokenizer_path_chat = model_path_chat\n",
    "tokenizer = AutoTokenizer.from_pretrained(tokenizer_path_chat, trust_remote_code=True, encode_special_tokens=True)\n",
    "model = AutoModel.from_pretrained(model_path_chat, load_in_8bit=False, trust_remote_code=True).to(device)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "start_time": "2023-11-23T19:23:03.047154Z",
     "end_time": "2023-11-23T19:23:13.086148Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "outputs": [],
   "source": [
    "def answer(prompt):\n",
    "    inputs = tokenizer(prompt, return_tensors=\"pt\").to(device)\n",
    "    response = model.generate(input_ids=inputs[\"input_ids\"], max_new_tokens=max_new_tokens, history=None,\n",
    "                              temperature=temperature,\n",
    "                              top_p=top_p, do_sample=True)\n",
    "    response = response[0, inputs[\"input_ids\"].shape[-1]:]\n",
    "    answer = tokenizer.decode(response, skip_special_tokens=True)\n",
    "    return answer"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "start_time": "2023-11-23T19:23:13.086148Z",
     "end_time": "2023-11-23T19:23:13.086148Z"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "在本样例中，模型应该输出的标准答案应该为:\n",
    "'news_sports'"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "outputs": [],
   "source": [
    "PROMPT1 = \"\"\"\n",
    "<|system|>\n",
    "你是一个专业的新闻专家，请根据我提供的新闻信息，包括新闻标题，新闻关键词等信息，你需要在这些类中选择其中一个，它们分别是:\n",
    "news_story\n",
    "news_culture\n",
    "news_sports\n",
    "news_finance\n",
    "news_house\n",
    "news_car\n",
    "news_edu\n",
    "news_tech\n",
    "news_military\n",
    "news_travel\n",
    "news_world\n",
    "stock\n",
    "news_agriculture\n",
    "news_game\n",
    "这是我的信息:\n",
    "<|user|>\n",
    "新闻标题: 女乒今天排兵布阵不合理，丁宁昨天刚打硬仗，今天应该打第一单打，你认为呢？\n",
    "新闻关键词: 无\n",
    "<|assistant|>\n",
    "\"\"\""
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "start_time": "2023-11-23T19:23:13.086148Z",
     "end_time": "2023-11-23T19:23:13.095864Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "outputs": [
    {
     "data": {
      "text/plain": "' 新闻类型：体育新闻\\n  新闻关键词：女乒，排兵布阵，丁宁，硬仗，第一单打'"
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "answer(PROMPT1)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "start_time": "2023-11-23T19:23:13.095864Z",
     "end_time": "2023-11-23T19:23:14.034154Z"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "很明显，使用这个提示词没有完成要求。我们更换一个提示词，经过优化后，看是否能达标。\n",
    "\n",
    "我们为提示词中设定了更多的限定词汇，并用Markdown语法规范化了输出的格式。修改后的提示词如下："
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "outputs": [],
   "source": [
    "PROMPT2 = \"\"\"\n",
    "<|system|>\n",
    "请根据我提供的新闻信息，格式为\n",
    "```\n",
    "新闻标题: xxx\n",
    "新闻关键词: xxx\n",
    "```\n",
    "你要对每一行新闻类别进行分类并告诉我结果，不要返回其他信息和多于的文字，这些类别是:\n",
    "news_story\n",
    "news_culture\n",
    "news_sports\n",
    "news_finance\n",
    "news_house\n",
    "news_car\n",
    "news_edu\n",
    "news_tech\n",
    "news_military\n",
    "news_travel\n",
    "news_world\n",
    "stock\n",
    "news_agriculture\n",
    "news_game\n",
    "我将为你提供一些新闻标题和关键词，你需要根据这些信息，对这些新闻进行分类，每条一行，格式为\n",
    "```\n",
    "类别中的其中一个\n",
    "```\n",
    "不要返回其他内容，例如新闻标题，新闻关键词等等，只需要返回分类结果即可\n",
    "<|user|>\n",
    "新闻标题: 女乒今天排兵布阵不合理，丁宁昨天刚打硬仗，今天应该打第一单打，你认为呢？\n",
    "新闻关键词: 无\n",
    "<|assistant|>\n",
    "\"\"\""
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "start_time": "2023-11-23T19:23:14.044595Z",
     "end_time": "2023-11-23T19:23:14.044595Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "outputs": [
    {
     "data": {
      "text/plain": "'news_sports'"
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "answer(PROMPT2)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "start_time": "2023-11-23T19:23:14.044595Z",
     "end_time": "2023-11-23T19:23:14.217801Z"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "这一次，模型成功给出了理想的答案。我们结束实操训练，删除模型并释放显存。"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "outputs": [],
   "source": [
    "del model\n",
    "torch.cuda.empty_cache()"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "start_time": "2023-11-23T19:23:14.217801Z",
     "end_time": "2023-11-23T19:23:14.373137Z"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 总结\n",
    "在本实践手册中，我们让开发者体验了使用不同提示词下，`ChatGLM3-6B` 模型在 `新闻标题分类` 任务中的表现。\n",
    "1. 模型在两次提示中都能正确的分类，验证了模型底座的能力。\n",
    "2. 在使用第二个提示词时，模型的输出符合格式要求，而且没有冗余的对话。符合我们的要求。\n",
    "3. 第二个提示词使用了一定的表示限定的词汇和更准确的任务要求，包括规定了返回格式，返回的行书等的。\n",
    "\n",
    "因此，通过上述内容，我们可以发现：\n",
    "对于有格式要求的任务，可以使用更具有格式化的提示词来完成任务。同时，使用更低的 `temperature` 和更高的 `top_p` 可以提高模型的输出质量。\n",
    "\n",
    "模型在训练中大量使用 Markdown格式。因此，用 ``` 符号来规范提示词将能提升模型输出格式化数据的准确率。"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
